-
Notifications
You must be signed in to change notification settings - Fork 288
Support MRL (Matryoshka Representation Learning) #676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the addition! Also I like the macros for Prometheus metrics, those look cleaner IMO.
P.S. Guessing that the snapshots were removed because those were not used, right?
Co-authored-by: Alvaro Bartolome <[email protected]>
oh, sorry for the missing context. you're right! looks like these test cases (router/tests) were all commented, so I've removed the snapshots. updated) thanks for the pointing this :) |
@kozistr QQ: what happens if you set normalize=True + mrl dimensions? Otherwise PR looks good. |
@michaelfeil Hi! thanks for double-checking this :) |
Is there something bringing this to mainline? |
core/src/infer.rs
Outdated
metrics::counter!("te_embed_success").increment(1); | ||
metrics::histogram!("te_embed_duration").record(total_time.as_secs_f64()); | ||
metrics::histogram!("te_embed_tokenization_duration") | ||
.record(response.metadata.tokenization.as_secs_f64()); | ||
metrics::histogram!("te_embed_queue_duration") | ||
.record(response.metadata.queue.as_secs_f64()); | ||
metrics::histogram!("te_embed_inference_duration") | ||
.record(response.metadata.inference.as_secs_f64()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you revert those changes ? They don't seem linked to the PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reverted 9131071
core/src/infer.rs
Outdated
permit: OwnedSemaphorePermit, | ||
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> { | ||
let start_time = Instant::now(); | ||
|
||
if self.is_splade() && normalize { | ||
let counter = metrics::counter!("te_request_failure", "err" => "model_type"); | ||
counter.increment(1); | ||
metrics::counter!("te_request_failure", "err" => "model_type").increment(1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sam here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reverted 9131071
let message = "`normalize` is not available for SPLADE models".to_string(); | ||
tracing::error!("{message}"); | ||
return Err(TextEmbeddingsError::Backend(BackendError::Inference( | ||
message, | ||
))); | ||
} | ||
|
||
if let Some(dimensions) = dimensions { | ||
if dimensions == 0 { | ||
metrics::counter!("te_request_failure", "err" => "validation").increment(1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't they also be smaller than the maximum embedding dimension ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At that time, I was considering silently returning the embedding with the original size when the given dimension is larger than the expected size, like line 274.
On second thought, like you mentioned, it'd be better to raise an error explicitly when the size is larger than expected in terms of validity.
I'll add an extra validation logic to check whether the given size is larger than the size of the embedding. thanks for catching this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a1b1a26
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM !
What does this PR do?
Fixes #673
Add
dimensions
field to the embed request API spec.router/tests
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@Narsil @alvarobartt